package com.yhx.security.gfoBank.api.com.gmcrypto.mf.gm;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;

import org.bouncycastle.asn1.ASN1Encodable;
import org.bouncycastle.asn1.ASN1Encoding;
import org.bouncycastle.asn1.ASN1InputStream;
import org.bouncycastle.asn1.ASN1Integer;
import org.bouncycastle.asn1.ASN1OctetString;
import org.bouncycastle.asn1.ASN1Sequence;
import org.bouncycastle.asn1.DEROctetString;
import org.bouncycastle.asn1.DERSequence;
import org.bouncycastle.asn1.gm.GMNamedCurves;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.asn1.sec.ECPrivateKey;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.asn1.x9.X9ECParameters;
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.CryptoException;
import org.bouncycastle.crypto.InvalidCipherTextException;
import org.bouncycastle.crypto.KeyGenerationParameters;
import org.bouncycastle.crypto.engines.SM2Engine;
import org.bouncycastle.crypto.generators.ECKeyPairGenerator;
import org.bouncycastle.crypto.params.ECDomainParameters;
import org.bouncycastle.crypto.params.ECKeyGenerationParameters;
import org.bouncycastle.crypto.params.ECPrivateKeyParameters;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.bouncycastle.crypto.params.ParametersWithID;
import org.bouncycastle.crypto.params.ParametersWithRandom;
import org.bouncycastle.crypto.signers.SM2Signer;
import org.bouncycastle.math.ec.ECConstants;

/**
 * SM2 utility class providing encryption/decryption and signing/verification on the
 * sm2p256v1 curve.
 * <p>All operations work on raw bytes. Make sure the charset used to turn text into bytes
 * before encryption is the same one used to turn the decrypted bytes back into text.
 * <p>Unless stated otherwise, every method expects raw binary input. Data that is hex- or
 * Base64-encoded must be decoded before it is passed in.
 * @author liangruxing-yfzx
 *
 */
public class SM2Util {

    private static final X9ECParameters SM2P256V1 = GMNamedCurves.getByName("sm2p256v1");

    /** Domain parameters of the curve. They are immutable, so one shared instance is reused. */
    private static final ECDomainParameters DOMAIN_PARAMS =
            new ECDomainParameters(SM2P256V1.getCurve(), SM2P256V1.getG(), SM2P256V1.getN());

    /** Byte length of one field element / scalar on the curve (32 for sm2p256v1). */
    private static final int CURVE_LENGTH = (SM2P256V1.getCurve().getFieldSize() + 7) / 8;

    /** Output length of SM3, i.e. the length of the C3 component of an SM2 ciphertext. */
    private static final int SM3_DIGEST_LENGTH = 32;

    /** Leading byte of an uncompressed EC point encoding (0x04 || X || Y). */
    private static final byte UNCOMPRESSED_POINT_FLAG = 0x04;

    /** User ID used by the convenience sign/verify overloads. */
    private static final byte[] DEFAULT_USER_ID = "1234567812345678".getBytes(StandardCharsets.US_ASCII);

    /**
     * SM2 encryption.
     * @param publicKey public key as raw bytes (decode hex/Base64 first). It must be an encoded
     *        EC point such as 0x04||X||Y. A 64-byte value is treated as X||Y and gets the 0x04
     *        prefix added.
     * @param data plaintext to encrypt
     * @return the ciphertext, DER-encoded as SEQUENCE { INTEGER x, INTEGER y, OCTET STRING C3,
     *         OCTET STRING C2 }
     * @throws InvalidCipherTextException if the SM2 engine rejects the input
     * @throws IOException if DER encoding of the ciphertext fails
     */
    public static byte[] encrypt(byte[] publicKey, byte[] data) throws InvalidCipherTextException, IOException {
        ECPublicKeyParameters pubKeyParameters = toPublicKeyParameters(publicKey);
        SM2Engine engine = new SM2Engine();
        engine.init(true, new ParametersWithRandom(pubKeyParameters, new SecureRandom()));
        byte[] cipher = engine.processBlock(data, 0, data.length);
        return encodeSM2CipherToDER(cipher);
    }

    /**
     * SM2 decryption.
     * @param privateKey private key scalar as raw big-endian bytes (decode hex/Base64 first)
     * @param encryptedData DER-encoded ciphertext in the format produced by {@link #encrypt}
     * @return the decrypted plaintext
     * @throws InvalidCipherTextException if the SM2 engine rejects the ciphertext
     * @throws IOException if the DER ciphertext cannot be parsed
     */
    public static byte[] decrypt(byte[] privateKey, byte[] encryptedData) throws InvalidCipherTextException, IOException {
        ECPrivateKeyParameters priKeyParameters = toPrivateKeyParameters(privateKey);
        SM2Engine engine = new SM2Engine();
        engine.init(false, priKeyParameters);
        byte[] cipher = decodeDERSM2Cipher(encryptedData);
        return engine.processBlock(cipher, 0, cipher.length);
    }

    /**
     * SM2 signature using the default user ID {@code 1234567812345678}.
     * @param privateKey private key scalar as raw bytes
     * @param sourceData data to sign
     * @return the DER-encoded signature value
     * @throws CryptoException if signature generation fails
     */
    public static byte[] sign(byte[] privateKey, byte[] sourceData) throws CryptoException {
        return sign(DEFAULT_USER_ID, privateKey, sourceData);
    }

    /**
     * SM2 signature with an explicit user ID.
     * @param userId user ID. Without a prior agreement use {@code 1234567812345678}. If
     *        {@code null}, no ID is passed to the signer and the signer's own default applies.
     * @param privateKey private key scalar as raw bytes
     * @param sourceData data to sign
     * @return the DER-encoded signature value
     * @throws CryptoException if signature generation fails
     */
    public static byte[] sign(byte[] userId, byte[] privateKey, byte[] sourceData) throws CryptoException {
        CipherParameters param = new ParametersWithRandom(toPrivateKeyParameters(privateKey), new SecureRandom());
        if (userId != null) {
            param = new ParametersWithID(param, userId);
        }
        SM2Signer signer = new SM2Signer();
        signer.init(true, param);
        signer.update(sourceData, 0, sourceData.length);
        return signer.generateSignature();
    }

    /**
     * SM2 signature verification using the default user ID {@code 1234567812345678}.
     * @param publicKey public key as raw bytes (encoded point, or 64-byte X||Y)
     * @param sourceData data that was signed
     * @param signData signature value
     * @return {@code true} if the signature verifies
     */
    public static boolean verifySign(byte[] publicKey, byte[] sourceData, byte[] signData) {
        return verifySign(DEFAULT_USER_ID, publicKey, sourceData, signData);
    }

    /**
     * SM2 signature verification with an explicit user ID.
     * @param userId user ID; must match the one used for signing. If {@code null}, no ID is
     *        passed to the signer and the signer's own default applies.
     * @param publicKey public key as raw bytes (encoded point, or 64-byte X||Y)
     * @param sourceData data that was signed
     * @param signData signature value
     * @return {@code true} if the signature verifies
     */
    public static boolean verifySign(byte[] userId, byte[] publicKey, byte[] sourceData, byte[] signData) {
        CipherParameters param = toPublicKeyParameters(publicKey);
        if (userId != null) {
            param = new ParametersWithID(param, userId);
        }
        SM2Signer signer = new SM2Signer();
        signer.init(false, param);
        signer.update(sourceData, 0, sourceData.length);
        return signer.verifySignature(signData);
    }

    /**
     * Extracts the public key from a DER-encoded SubjectPublicKeyInfo.
     * @param derData DER-encoded SubjectPublicKeyInfo
     * @return the contents of the subjectPublicKey BIT STRING, i.e. the encoded EC point
     */
    public static byte[] getPublicKey(byte[] derData) {
        SubjectPublicKeyInfo info = SubjectPublicKeyInfo.getInstance(derData);
        return info.getPublicKeyData().getBytes();
    }

    /**
     * Extracts the private key scalar from a DER-encoded PKCS#8 PrivateKeyInfo that wraps an
     * EC private key.
     * @param derData DER-encoded PrivateKeyInfo
     * @return the private key as exactly 32 big-endian bytes (left-padded with zeros), or
     *         {@code null} if the key does not fit into 32 bytes
     * @throws IOException if the inner private key structure cannot be parsed
     */
    public static byte[] getPrivateKey(byte[] derData) throws IOException {
        PrivateKeyInfo info = PrivateKeyInfo.getInstance(derData);
        BigInteger key = ECPrivateKey.getInstance(info.parsePrivateKey()).getKey();
        if (key.bitLength() > CURVE_LENGTH * 8) {
            return null;
        }
        return toFixedLengthBytes(key, CURVE_LENGTH);
    }

    /**
     * Generates a new SM2 key pair.
     * @return the key pair. The private key is 32 bytes (scalar d). The public key is 64 bytes
     *         (X||Y, without the 0x04 prefix).
     */
    public static SM2KeyPair generateKeyPair() {
        ECKeyPairGenerator generator = new ECKeyPairGenerator();
        generator.init(new ECKeyGenerationParameters(DOMAIN_PARAMS, new SecureRandom()));

        // GM/T 0003 (GB/T 32918) requires d in [1, n-2]: SM2 signing computes (1 + d)^-1 mod n,
        // which does not exist for d = n - 1. Regenerate in the (negligibly rare) bad case.
        BigInteger upperExclusive = SM2P256V1.getN().subtract(BigInteger.ONE);
        ECPrivateKeyParameters ecpriv;
        ECPublicKeyParameters ecpub;
        do {
            AsymmetricCipherKeyPair keypair = generator.generateKeyPair();
            ecpriv = (ECPrivateKeyParameters) keypair.getPrivate();
            ecpub = (ECPublicKeyParameters) keypair.getPublic();
        } while (ecpriv.getD().signum() <= 0 || ecpriv.getD().compareTo(upperExclusive) >= 0);

        byte[] privKey = toFixedLengthBytes(ecpriv.getD(), CURVE_LENGTH);
        // Uncompressed encoding is 0x04 || X || Y with fixed-length coordinates; drop the prefix.
        byte[] encodedQ = ecpub.getQ().getEncoded(false);
        byte[] pubKey = slice(encodedQ, 1, encodedQ.length - 1);
        return new SM2KeyPair(privKey, pubKey);
    }

    /**
     * Converts a raw C1C2C3 ciphertext into the DER encoding defined by GM/T 0009.
     * @param cipher raw ciphertext: 0x04 || C1.x || C1.y || C2 || C3
     * @return DER-encoded SEQUENCE { INTEGER x, INTEGER y, OCTET STRING C3, OCTET STRING C2 }
     * @throws IllegalArgumentException if {@code cipher} is too short or C1 is not an
     *         uncompressed point
     * @throws IOException if DER encoding fails
     */
    public static byte[] encodeSM2CipherToDER(byte[] cipher) throws IOException {
        int fixedLength = 1 + 2 * CURVE_LENGTH + SM3_DIGEST_LENGTH;
        if (cipher == null || cipher.length < fixedLength || cipher[0] != UNCOMPRESSED_POINT_FLAG) {
            throw new IllegalArgumentException(
                    "expected C1C2C3 ciphertext of at least " + fixedLength
                            + " bytes starting with an uncompressed C1 point");
        }
        int c2Length = cipher.length - fixedLength;
        int c2Offset = 1 + 2 * CURVE_LENGTH;

        BigInteger c1x = new BigInteger(1, slice(cipher, 1, CURVE_LENGTH));
        BigInteger c1y = new BigInteger(1, slice(cipher, 1 + CURVE_LENGTH, CURVE_LENGTH));
        byte[] c2 = slice(cipher, c2Offset, c2Length);
        byte[] c3 = slice(cipher, c2Offset + c2Length, SM3_DIGEST_LENGTH);

        // GM/T 0009 orders the hash (C3) before the ciphertext body (C2).
        ASN1Encodable[] fields = {
            new ASN1Integer(c1x),
            new ASN1Integer(c1y),
            new DEROctetString(c3),
            new DEROctetString(c2)
        };
        return new DERSequence(fields).getEncoded(ASN1Encoding.DER);
    }

    /**
     * Decodes a DER-encoded SM2 ciphertext (GM/T 0009-2012) into the raw C1C2C3 format.
     * @param derCipher DER-encoded SEQUENCE { INTEGER x, INTEGER y, OCTET STRING C3,
     *        OCTET STRING C2 }
     * @return raw ciphertext 0x04 || C1.x || C1.y || C2 || C3, with coordinates left-padded to
     *         the curve length
     * @throws IOException if the input is empty, malformed, has the wrong number of elements,
     *         or carries a C1 coordinate outside the field size
     */
    public static byte[] decodeDERSM2Cipher(byte[] derCipher) throws IOException {
        ASN1Sequence seq;
        try (ASN1InputStream in = new ASN1InputStream(new ByteArrayInputStream(derCipher))) {
            ASN1Encodable obj = in.readObject();
            if (obj == null) {
                throw new IOException("empty SM2 ciphertext");
            }
            seq = ASN1Sequence.getInstance(obj);
        }
        if (seq.size() != 4) {
            throw new IOException("SM2 ciphertext must be a SEQUENCE of 4 elements, got " + seq.size());
        }

        BigInteger c1x = ASN1Integer.getInstance(seq.getObjectAt(0)).getValue();
        BigInteger c1y = ASN1Integer.getInstance(seq.getObjectAt(1)).getValue();
        byte[] c3 = ASN1OctetString.getInstance(seq.getObjectAt(2)).getOctets();
        byte[] c2 = ASN1OctetString.getInstance(seq.getObjectAt(3)).getOctets();
        checkCoordinate(c1x);
        checkCoordinate(c1y);

        byte[] cipherText = new byte[1 + 2 * CURVE_LENGTH + c2.length + c3.length];
        int pos = 0;
        cipherText[pos++] = UNCOMPRESSED_POINT_FLAG;
        System.arraycopy(toFixedLengthBytes(c1x, CURVE_LENGTH), 0, cipherText, pos, CURVE_LENGTH);
        pos += CURVE_LENGTH;
        System.arraycopy(toFixedLengthBytes(c1y, CURVE_LENGTH), 0, cipherText, pos, CURVE_LENGTH);
        pos += CURVE_LENGTH;
        System.arraycopy(c2, 0, cipherText, pos, c2.length);
        pos += c2.length;
        System.arraycopy(c3, 0, cipherText, pos, c3.length);
        return cipherText;
    }

    /** Normalizes a raw public key (64-byte X||Y gets the 0x04 prefix) and decodes it on the curve. */
    private static ECPublicKeyParameters toPublicKeyParameters(byte[] publicKey) {
        byte[] encoded = publicKey;
        if (publicKey.length == 2 * CURVE_LENGTH) {
            encoded = new byte[publicKey.length + 1];
            encoded[0] = UNCOMPRESSED_POINT_FLAG;
            System.arraycopy(publicKey, 0, encoded, 1, publicKey.length);
        }
        return new ECPublicKeyParameters(SM2P256V1.getCurve().decodePoint(encoded), DOMAIN_PARAMS);
    }

    /** Wraps a raw big-endian private key scalar as key parameters on the SM2 curve. */
    private static ECPrivateKeyParameters toPrivateKeyParameters(byte[] privateKey) {
        return new ECPrivateKeyParameters(new BigInteger(1, privateKey), DOMAIN_PARAMS);
    }

    /** Rejects a C1 coordinate that is negative or wider than the curve's field size. */
    private static void checkCoordinate(BigInteger coordinate) throws IOException {
        if (coordinate.signum() < 0 || coordinate.bitLength() > CURVE_LENGTH * 8) {
            throw new IOException("SM2 ciphertext C1 coordinate out of range");
        }
    }

    /** Returns a copy of {@code length} bytes of {@code src} starting at {@code from}. */
    private static byte[] slice(byte[] src, int from, int length) {
        byte[] out = new byte[length];
        System.arraycopy(src, from, out, 0, length);
        return out;
    }

    /**
     * Formats a BigInteger as exactly {@code length} big-endian bytes. {@code toByteArray()}
     * has a variable length (it may add a sign byte or omit leading zeros), so shorter values
     * are left-padded with zeros and longer ones keep only their low-order bytes. Callers
     * should pass a non-negative value that fits into {@code length} bytes, so that only the
     * sign byte is dropped.
     */
    private static byte[] toFixedLengthBytes(BigInteger value, int length) {
        byte[] raw = value.toByteArray();
        if (raw.length == length) {
            return raw;
        }
        byte[] out = new byte[length];
        if (raw.length > length) {
            System.arraycopy(raw, raw.length - length, out, 0, length);
        } else {
            System.arraycopy(raw, 0, out, length - raw.length, raw.length);
        }
        return out;
    }

}
